Skip to content

Add fine_coordinates to model with validation against StaticInputs#999

Merged
frodre merged 6 commits intomainfrom
refactor/add-fine-coords-to-model
Mar 24, 2026
Merged

Add fine_coordinates to model with validation against StaticInputs#999
frodre merged 6 commits intomainfrom
refactor/add-fine-coords-to-model

Conversation

@frodre
Copy link
Copy Markdown
Collaborator

@frodre frodre commented Mar 21, 2026

PR 3/3 for #971 split

Changes

  • Add full_fine_coords: LatLonCoordinates to DiffusionModel as the canonical source of truth for the fine-resolution grid, decoupled from StaticInputs

  • Validate at build time that static_inputs.coords matches full_fine_coords when both are provided

  • Remove dead runtime guard in _get_input_from_coarse for missing static inputs when use_fine_topography=True; this is now enforced at build time

  • Add fine_coordinates_path to CheckpointModelConfig for backwards-compatible loading of old checkpoints that lack stored fine coordinates

  • Add StaticInputs.from_state_backwards_compatible for checkpoint loading that handles both old and new state formats

  • Remove _downscale_coord bandaid from predict.py; get_fine_coords_for_batch now uses full_fine_coords directly

  • Add fine_coords field to PairedGriddedData and thread it through to TrainerConfig.build

  • Tests added

@frodre frodre force-pushed the refactor/static-input-coordinates branch from ff04394 to 665dfb9 Compare March 23, 2026 21:43
Base automatically changed from refactor/static-input-coordinates to main March 23, 2026 23:09
frodre added 4 commits March 24, 2026 09:56
  Here's a summary of the fixes made in this session:

  test_models.py:
  - test_checkpoint_model_build_raises_when_checkpoint_has_static_inputs: added full_fine_coords=static_inputs.coords to _get_diffusion_model
   call
  - Added test_checkpoint_model_build_with_fine_coordinates_path: tests loading an old-format checkpoint (no full_fine_coords in state) using
   fine_coordinates_path
  - Removed unused batch_size variable in test_model_error_cases

  data/config.py:
  - Captured fine_latlon_coords = dataset_fine_subset.subset_latlon_coordinates before the variable gets reassigned to
  BatchItemDatasetAdapter

  predict.py:
  - Fixed EventDownscaler.run() to pass batch (the full BatchData) instead of batch[0] (a BatchItem) to get_fine_coords_for_batch

  test_predict.py:
  - Both test_predictor_runs and test_predictor_renaming: load static_inputs first, then pass full_fine_coords=static_inputs.coords to
  model_config.build()

  inference/test_inference.py:
  - checkpointed_model_config fixture: added full_fine_coords=static_inputs.coords to model_config.build()# Please enter the commit message for your changes. Lines starting
@frodre frodre force-pushed the refactor/add-fine-coords-to-model branch from c19923d to 7af66f7 Compare March 24, 2026 16:57
return ds


def load_coords_from_path(path: str) -> LatLonCoordinates:
Copy link
Copy Markdown
Collaborator Author

@frodre frodre Mar 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Re-added to use in CheckpointModelConfig. Added the small helper _open_ds_from_path to de-dupe code

@frodre frodre marked this pull request as ready for review March 24, 2026 19:40
@frodre frodre merged commit 2ea6382 into main Mar 24, 2026
7 checks passed
@frodre frodre deleted the refactor/add-fine-coords-to-model branch March 24, 2026 21:37
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants